{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spatial Adaptive Graph Neural Network for POI Graph Learning \n",
    "In this tuorial, we will go through how to run the Spatial Adaptive Graph Neural Network (SA-GNN) to learn on the POI graph. If you are intersted in more details, please refer to the paper \"Competitive analysis for points of interest\".\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CPUPlace"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import pgl\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from random import shuffle\n",
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import paddle.nn.functional as F\n",
    "from sagnn import SpatialOrientedAGG, SpatialAttnProp\n",
    "paddle.set_device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset and construct the POI graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset(file_path, dataset):\n",
    "    \"\"\"\n",
    "    1) Load the POI dataset from four files: edges file, two-dimensional coordinate file, POI feature file and label file.\n",
    "    2) Construct the PGL-based graph and return the pgl.Graph instance.\n",
    "    \"\"\"\n",
    "    edges = pd.read_table(os.path.join(file_path, '%s.edge' % dataset), header=None, sep=' ')\n",
    "    # ex, ey = np.array(edges[:][0]), np.array(edges[:][1])\n",
    "    edges = list(zip(edges[:][0],edges[:][1]))\n",
    "    coords = pd.read_table(os.path.join(file_path, '%s.coord' % dataset), header=None, sep=' ')\n",
    "    coords = np.array(coords)\n",
    "\n",
    "    feat_path = os.path.join(file_path, '%s.feat' % dataset) # pickle file\n",
    "    if os.path.exists(feat_path):\n",
    "        with open(feat_path, 'rb') as f:\n",
    "            features = pickle.load(f)\n",
    "    else:\n",
    "        features = np.eye(len(coords))\n",
    "\n",
    "    graph = pgl.Graph(edges, num_nodes=len(coords), node_feat={\"feat\": features, 'coord': coords})\n",
    "    ind_labels = pd.read_table(os.path.join(file_path, '%s.label' % dataset), header=None, sep=' ')\n",
    "    inds_1 = np.array(ind_labels)[:,0]\n",
    "    inds_2 = np.array(ind_labels)[:,1]\n",
    "    labels = np.array(ind_labels)[:,2:]\n",
    "    \n",
    "    return graph, (inds_1, inds_2), labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the SA-GNN model for link prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseLayer(nn.Layer):\n",
    "    def __init__(self, in_dim, out_dim, activation=F.relu, bias=True):\n",
    "        super(DenseLayer, self).__init__()\n",
    "        self.activation = activation\n",
    "        if not bias:\n",
    "            self.fc = nn.Linear(in_dim, out_dim, bias_attr=False)\n",
    "        else:\n",
    "            self.fc = nn.Linear(in_dim, out_dim)\n",
    "    \n",
    "    def forward(self, input_feat):\n",
    "        return self.activation(self.fc(input_feat))\n",
    "\n",
    "class SAGNNModel(nn.Layer):\n",
    "    def __init__(self, infeat_dim, hidden_dim=128, dense_dims=[128,128], num_heads=4, feat_drop=0.2, num_sectors=4, max_dist=0.1, grid_len=0.001, num_convs=1):\n",
    "        super(SAGNNModel, self).__init__()\n",
    "        self.num_convs = num_convs\n",
    "        self.agg_layers = nn.LayerList()\n",
    "        self.prop_layers = nn.LayerList()\n",
    "        in_dim = infeat_dim\n",
    "        for i in range(num_convs):\n",
    "            agg = SpatialOrientedAGG(in_dim, hidden_dim, num_sectors, transform=False, activation=None)\n",
    "            prop = SpatialAttnProp(hidden_dim, hidden_dim, num_heads, feat_drop, max_dist, grid_len, activation=None)\n",
    "            self.agg_layers.append(agg)\n",
    "            self.prop_layers.append(prop)\n",
    "            in_dim = num_heads * hidden_dim\n",
    "\n",
    "        self.mlp = nn.LayerList()\n",
    "        for hidden_dim in dense_dims:\n",
    "            self.mlp.append(DenseLayer(in_dim, hidden_dim, activation=F.relu))\n",
    "            in_dim = hidden_dim\n",
    "        self.output_layer = nn.Linear(2 * in_dim, 1)\n",
    "    \n",
    "    def forward(self, graph, inds):\n",
    "        feat_h = graph.node_feat['feat']\n",
    "        for i in range(self.num_convs):\n",
    "            feat_h = self.agg_layers[i](graph, feat_h)\n",
    "            graph = graph.tensor()\n",
    "            feat_h = self.prop_layers[i](graph, feat_h)\n",
    "\n",
    "        for fc in self.mlp:\n",
    "            feat_h = fc(feat_h)\n",
    "        feat_h = paddle.concat([paddle.gather(feat_h, inds[0]), paddle.gather(feat_h, inds[1])], axis=-1)\n",
    "        output = F.sigmoid(self.output_layer(feat_h))\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we load a mock dataset for demonstration, you can load the full dataset as you want.\n",
    "\n",
    "Note that all needed files should include:\n",
    "1) one edge file (dataset.edge) for POI graph construction;\n",
    "2) one coordinate file (dataset.coord) for POI position;\n",
    "3) one label file (dataset.label) for training model;\n",
    "4) one feature file (dataset.feat) for POI feature loading, which is optional. If there is no dataset.feat, the default POI feature is the one-hot vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset num: 300 training num: 180\n"
     ]
    }
   ],
   "source": [
    "graph, inds, labels = load_dataset('./data/', 'mock_poi')\n",
    "ids = [i for i in range(len(labels))]\n",
    "shuffle(ids)\n",
    "train_num = int(0.6*len(labels))\n",
    "train_inds = (inds[0][ids[:train_num]], inds[1][ids[:train_num]])\n",
    "test_inds = (inds[0][ids[train_num:]], inds[1][ids[train_num:]])\n",
    "train_labels = labels[ids[:train_num]]\n",
    "test_labels = labels[ids[train_num:]]\n",
    "print(\"dataset num: %s\" % (len(labels)), \"training num: %s\" % (len(train_labels)))\n",
    "infeat_dim = graph.node_feat['feat'].shape[0]\n",
    "\n",
    "model = SAGNNModel(infeat_dim)\n",
    "optim = paddle.optimizer.Adam(0.001, parameters=model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Strart training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0 train/loss:0.5990816\n",
      "epoch:1 train/loss:0.5992603\n",
      "epoch:2 train/loss:0.5916268\n",
      "epoch:3 train/loss:0.58739626\n",
      "epoch:4 train/loss:0.56397086\n",
      "test loss: 0.63546157, test accuracy: [0.65]\n"
     ]
    }
   ],
   "source": [
    "def train(model, graph, inds, labels, optim):\n",
    "    model.train()\n",
    "    graph = graph.tensor()\n",
    "    inds = paddle.to_tensor(inds, 'int64')\n",
    "    labels = paddle.to_tensor(labels, 'float32')\n",
    "    preds = model(graph, inds)\n",
    "    bce_loss = paddle.nn.BCELoss()\n",
    "    loss = bce_loss(preds, labels)\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "    optim.clear_grad()\n",
    "    return loss.numpy()[0]\n",
    "\n",
    "def evaluate(model, graph, inds, labels):\n",
    "    model.eval()\n",
    "    graph = graph.tensor()\n",
    "    inds = paddle.to_tensor(inds, 'int64')\n",
    "    labels = paddle.to_tensor(labels, 'float32')\n",
    "    preds = model(graph, inds)\n",
    "    bce_loss = paddle.nn.BCELoss()\n",
    "    loss = bce_loss(preds, labels)\n",
    "    return loss.numpy()[0],  1.0*np.sum(preds.numpy().astype(int) == labels.numpy().astype(int), axis=0) / len(labels)\n",
    "\n",
    "for epoch_id in range(5):\n",
    "    train_loss = train(model, graph, train_inds, train_labels, optim)\n",
    "    print(\"epoch:%d train/loss:%s\" % (epoch_id, train_loss))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, graph, test_inds, test_labels)\n",
    "print(\"test loss: %s, test accuracy: %s\" % (test_loss, test_acc))\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4702ab958f6790120ca332725aeb1ce0c0586ea4908ac0580a3c490ed2402bb2"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('paddle': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
